import collections
import math

import torch
from torch import nn
from torch.utils.data import DataLoader

from config import *
from transformer import TransformerEncoder, TransformerDecoder, EncoderDecoder
from attention import sequence_mask
from data import loaddata, MyDataset, Lang
from train import loss_fn


def bleu_weighted(pred_seq, label_seq, k):  # geometric weights 1/2^n
    """BLEU with geometrically decaying n-gram weights (0.5 ** n).

    Both sequences are whitespace-separated token strings.  The score is
    the brevity penalty exp(min(0, 1 - len_label / len_pred)) multiplied,
    for n = 1..k, by p_n ** (0.5 ** n), where p_n is the clipped n-gram
    precision against the reference.

    Args:
        pred_seq:  predicted sentence, tokens joined by single spaces.
        label_seq: reference sentence, tokens joined by single spaces.
        k:         highest n-gram order to include.

    Returns:
        A float in [0, 1]; 0 when the prediction has fewer than k tokens.
    """
    pred_tokens, label_tokens = pred_seq.split(' '), label_seq.split(' ')
    len_pred, len_label = len(pred_tokens), len(label_tokens)
    # Guard (consistent with bleu()): a prediction shorter than k tokens has
    # no k-grams, and the unguarded loop divided by zero at n = len_pred + 1.
    if len_pred < k:
        return 0
    score = math.exp(min(0, 1 - len_label / len_pred))  # brevity penalty
    for n in range(1, k + 1):
        num_matches, label_subs = 0, collections.defaultdict(int)
        # Count reference n-grams so each one can be matched at most as many
        # times as it occurs (clipped precision).
        for i in range(len_label - n + 1):
            label_subs[' '.join(label_tokens[i: i + n])] += 1
        for i in range(len_pred - n + 1):
            if label_subs[' '.join(pred_tokens[i: i + n])] > 0:
                num_matches += 1
                label_subs[' '.join(pred_tokens[i: i + n])] -= 1
        score *= math.pow(num_matches / (len_pred - n + 1), math.pow(0.5, n))
    return score


def bleu_seq(pred_seq, label_seq, k):  # per-order weights 1/n
    """BLEU over token strings, weighting each n-gram precision by 1/n.

    Both sequences are whitespace-separated token strings.  The score is
    the brevity penalty exp(min(0, 1 - len_label / len_pred)) multiplied,
    for n = 1..k, by p_n ** (1/n), where p_n is the clipped n-gram
    precision against the reference.

    Args:
        pred_seq:  predicted sentence, tokens joined by single spaces.
        label_seq: reference sentence, tokens joined by single spaces.
        k:         highest n-gram order to include.

    Returns:
        A float in [0, 1]; 0 when the prediction has fewer than k tokens.
    """
    pred_tokens, label_tokens = pred_seq.split(' '), label_seq.split(' ')
    len_pred, len_label = len(pred_tokens), len(label_tokens)
    # Guard (consistent with bleu()): a prediction shorter than k tokens has
    # no k-grams, and the unguarded loop divided by zero at n = len_pred + 1.
    if len_pred < k:
        return 0
    score = math.exp(min(0, 1 - len_label / len_pred))  # brevity penalty
    for n in range(1, k + 1):
        num_matches, label_subs = 0, collections.defaultdict(int)
        # Count reference n-grams so each one can be matched at most as many
        # times as it occurs (clipped precision).
        for i in range(len_label - n + 1):
            label_subs[' '.join(label_tokens[i: i + n])] += 1
        for i in range(len_pred - n + 1):
            if label_subs[' '.join(pred_tokens[i: i + n])] > 0:
                num_matches += 1
                label_subs[' '.join(pred_tokens[i: i + n])] -= 1
        score *= math.pow(num_matches / (len_pred - n + 1), 1/n)
    return score


def bleu(pred_tokens, label_tokens, k):  # uniform weights 1/k
    """BLEU over already-tokenized sequences with uniform weights 1/k.

    Unlike bleu_weighted/bleu_seq this takes token *lists* (e.g. integer
    ids), and every n-gram precision contributes with the same exponent
    1/k.  N-grams are keyed by str(slice), which is valid for list inputs.

    Args:
        pred_tokens:  predicted token sequence (list).
        label_tokens: reference token sequence (list).
        k:            highest n-gram order to include.

    Returns:
        A float in [0, 1]; 0 when the prediction is empty or shorter than k.
    """
    len_pred, len_label = len(pred_tokens), len(label_tokens)
    # Check degenerate cases *before* the brevity penalty: the original
    # computed len_label / len_pred first, which raised ZeroDivisionError
    # on an empty prediction even though the guard below existed.
    if len_pred == 0 or len_pred < k:
        return 0
    score = math.exp(min(0, 1 - len_label / len_pred))  # brevity penalty
    for n in range(1, k + 1):
        num_matches, label_subs = 0, collections.defaultdict(int)
        # Count reference n-grams so each one can be matched at most as many
        # times as it occurs (clipped precision).
        for i in range(len_label - n + 1):
            label_subs[str(label_tokens[i: i + n])] += 1
        for i in range(len_pred - n + 1):
            if label_subs[str(pred_tokens[i: i + n])] > 0:
                num_matches += 1
                label_subs[str(pred_tokens[i: i + n])] -= 1
        score *= math.pow(num_matches / (len_pred - n + 1), 1/k)
    return score


def BLEU1_4(pred_seq, label_seq, b):  # b = [0, 0, 0, 0]
    """Accumulate BLEU-1 through BLEU-4 into the running totals ``b``.

    ``b`` is mutated in place (slot idx holds the BLEU-(idx+1) sum) and
    also returned for convenience.
    """
    for idx in range(4):
        b[idx] += bleu(pred_seq, label_seq, idx + 1)
    return b


def get_distinct(out_seq, n):
    """Distinct-n metric: unique n-grams divided by total n-grams.

    Returns 0 when the sequence is too short to contain any n-gram.
    N-grams are keyed by str(slice), matching bleu()'s convention.
    """
    grams = [str(out_seq[start: start + n])
             for start in range(len(out_seq) - n + 1)]
    return len(set(grams)) / len(grams) if grams else 0


def distinct_1_n(out_seq, n, d):
    """Accumulate distinct-1 .. distinct-n into the running totals ``d``.

    ``d`` is mutated in place (slot order-1 holds the distinct-order sum)
    and also returned for convenience.
    """
    for order in range(1, n + 1):
        d[order - 1] += get_distinct(out_seq, order)
    return d

def seq2seq_test(net, data_iter, lang, device):
    """Greedy-decode every example in ``data_iter`` and average BLEU/distinct.

    For each batch, the encoder is run once, then tokens are decoded one at a
    time with argmax (greedy search) for up to ``num_steps`` steps or until the
    end-of-sequence token is produced.  BLEU-1..4 and distinct-1..3 are summed
    per example and divided by the number of batches at the end.

    Args:
        net:       EncoderDecoder model with ``encoder`` and ``decoder`` parts.
        net.eval() is called, so the model is switched to eval mode.
        data_iter: iterable of (enc_X, enc_valid_len, Y, Y_valid_len) batches;
                   presumably batch size 1, since a single scalar is read per
                   decode step — TODO confirm against the DataLoader.
        lang:      vocabulary object (unused here except via the hard-coded ids).
        device:    torch device the batches and model run on.

    Returns:
        ([BLEU-1..4 averages], [distinct-1..3 averages]), each divided by
        ``len(data_iter)``.

    NOTE(review): relies on ``num_steps`` from ``from config import *`` and on
    ``torch`` being importable at module level.
    """
    net.eval()
    b = [0, 0, 0, 0]  # running BLEU-1..4 sums
    d = [0, 0, 0]     # running distinct-1..3 sums
    with torch.no_grad():  # disable gradient tracking during evaluation
        for batch in data_iter:
            enc_X, enc_valid_len, Y, Y_valid_len = [x.to(device) for x in batch]

            enc_outputs = net.encoder(enc_X, enc_valid_len)
            dec_state = net.decoder.init_state(enc_outputs, enc_valid_len)

            # Seed decoding with the <bos> token; 2 is hard-coded and
            # presumably equals lang.word2index['<bos>'] — TODO confirm.
            dec_X = torch.tensor([2  # lang.word2index['<bos>']
                                ], device=device).unsqueeze(0)
            output_seq = []

            for _ in range(num_steps):
                Y1, dec_state = net.decoder(dec_X, dec_state)
                dec_X = Y1.argmax(dim=2)  # greedy: highest-scoring index is the predicted token id
                pred = dec_X.squeeze(dim=0).type(torch.int32).item()  # drop the batch dim, take the scalar id
                output_seq.append(pred)
                if pred == 3: # eos — hard-coded <eos> id, presumably lang.word2index['<eos>']; TODO confirm
                    break
            # NOTE(review): Y[0:Y_valid_len] slices the *batch* dimension with
            # the valid length; this looks like it was meant to be
            # Y[0][:Y_valid_len] (trim padding from the single reference) — verify.
            b = BLEU1_4(output_seq, Y[0:Y_valid_len].squeeze().tolist(), b)
            d = distinct_1_n(output_seq, 3, d)
            # print(b)

    return  [bi/len(data_iter) for bi in b], [di/len(data_iter) for di in d]


if __name__ == "__main__":
    # Quick sanity check of the metric helpers on a toy token sequence.
    # To evaluate a trained model on the saved test split instead:
    # lang = torch.load(DATA_ROOT + "dialog.lang")
    # testdata = torch.load(DATA_ROOT + "testdata")
    # testiter = DataLoader(testdata, 1, shuffle=False)
    # encoder = torch.load(MODEL_ROOT + "trans_encoder1.mdl", map_location=torch.device('cpu'))
    # decoder = torch.load(MODEL_ROOT + "trans_decoder1.mdl", map_location=torch.device('cpu'))
    # net = EncoderDecoder(encoder, decoder)
    # b, d = seq2seq_test(net, testiter, lang, device)
    # print(b, d)

    prediction = [1, 2, 1]
    reference = [1, 2, 1, 2, 1]
    print(BLEU1_4(prediction, reference, [0, 0, 0, 0]))
    print(distinct_1_n(prediction, 3, [0, 0, 0]))



